{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "34qVD_ntSKLF"
   },
   "source": [
    "# Custom Training Logic with Lightning Integration and Lightning Hooks\n",
    "\n",
    "In this example, we showcase the ability for the user to define own training logic and easily integrate into Lightning workflow using a variety of Lightning hooks. \n",
    "A reference to these hooks is provided here: https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#hooks. \n",
    "\n",
    "The first part of this notebooks is equivalent to the basics tutorial from before, so we will speed ahead through that first: \n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "OCn3zpaIqgMc"
   },
   "source": [
    "## NeuroMANCER and Dependencies"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Qzy5Wot5k2Gf"
   },
   "source": [
    "### Install (Colab only)\n",
    "Skip this step when running locally."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "id": "X_3EvkSz0Fnz",
    "outputId": "23c06f6b-ab48-4763-c43c-40a325cacf87"
   },
   "outputs": [],
   "source": [
    "!pip install neuromancer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "LWyvndXlz0Fv"
   },
   "source": [
    "### Import"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "id": "POL27EJZxJmI"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import numpy as np\n",
    "import neuromancer.slim as slim\n",
    "import matplotlib.pyplot as plt\n",
    "import lightning.pytorch as pl \n",
    "\n",
    "from neuromancer.trainer import LitTrainer\n",
    "from neuromancer.problem import Problem\n",
    "from neuromancer.constraint import variable\n",
    "from neuromancer.dataset import DictDataset\n",
    "from neuromancer.loss import PenaltyLoss\n",
    "from neuromancer.modules import blocks\n",
    "from neuromancer.system import Node\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Problem formulation\n",
    "\n",
    "In this example we will solve parametric constrained [Rosenbrock problem](https://en.wikipedia.org/wiki/Rosenbrock_function):\n",
    "\n",
    "$$\n",
    "\\begin{align}\n",
    "&\\text{minimize } &&  (1-x)^2 + a(y-x^2)^2\\\\\n",
    "&\\text{subject to} && \\left(\\frac{p}{2}\\right)^2 \\le x^2 + y^2 \\le p^2\\\\\n",
    "& && x \\ge y\n",
    "\\end{align}\n",
    "$$\n",
    "\n",
    "with parameters $p, a$ and decision variables $x, y$.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "_WH7o7Wu1epw"
   },
   "source": [
    "### Lightning Dataset\n",
    "\n",
    "We constructy the dataset by sampling the parametric space."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "id": "_r6p2p6myHAh"
   },
   "outputs": [],
   "source": [
    "data_seed = 408  # random seed used for simulated data\n",
    "np.random.seed(data_seed)\n",
    "torch.manual_seed(data_seed)\n",
    "nsim = 5000  # number of datapoints: increase sample density for more robust results\n",
    "\n",
    "# create dictionaries with sampled datapoints with uniform distribution\n",
    "a_low, a_high, p_low, p_high = 0.2, 1.2, 0.5, 2.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "id": "Nu58M-8JyHy6"
   },
   "outputs": [],
   "source": [
    "\n",
    "def data_setup_function(nsim, a_low, a_high, p_low, p_high): \n",
    "\n",
    "    \n",
    "    samples_train = {\"a\": torch.FloatTensor(nsim, 1).uniform_(a_low, a_high),\n",
    "                    \"p\": torch.FloatTensor(nsim, 1).uniform_(p_low, p_high)}\n",
    "    samples_dev = {\"a\": torch.FloatTensor(nsim, 1).uniform_(a_low, a_high),\n",
    "                \"p\": torch.FloatTensor(nsim, 1).uniform_(p_low, p_high)}\n",
    "    samples_test = {\"a\": torch.FloatTensor(nsim, 1).uniform_(a_low, a_high),\n",
    "                \"p\": torch.FloatTensor(nsim, 1).uniform_(p_low, p_high)}\n",
    "    # create named dictionary datasets\n",
    "    train_data = DictDataset(samples_train, name='train')\n",
    "    dev_data = DictDataset(samples_dev, name='dev')\n",
    "    test_data = DictDataset(samples_test, name='test')\n",
    "\n",
    "    batch_size = 64\n",
    "\n",
    "    # Return the dict datasets in train, dev, test order, followed by batch_size \n",
    "    return train_data, dev_data, test_data, batch_size \n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We now define the **Problem()**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "id": "Ta_I_pjyyLzf"
   },
   "outputs": [],
   "source": [
    "# define neural architecture for the trainable solution map\n",
    "func = blocks.MLP(insize=2, outsize=2,\n",
    "                bias=True,\n",
    "                linear_map=slim.maps['linear'],\n",
    "                nonlin=nn.ReLU,\n",
    "                hsizes=[80] * 4)\n",
    "# wrap neural net into symbolic representation of the solution map via the Node class: sol_map(xi) -> x\n",
    "sol_map = Node(func, ['a', 'p'], ['x'], name='map')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Lxj77EFj7EO-"
   },
   "source": [
    "## Objective and Constraints in NeuroMANCER"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "id": "bcoVjphjyPp9"
   },
   "outputs": [],
   "source": [
    "# define decision variables\n",
    "x1 = variable(\"x\")[:, [0]]\n",
    "x2 = variable(\"x\")[:, [1]]\n",
    "# problem parameters sampled in the dataset\n",
    "p = variable('p')\n",
    "a = variable('a')\n",
    "\n",
    "# objective function\n",
    "f = (1-x1)**2 + a*(x2-x1**2)**2\n",
    "obj = f.minimize(weight=1.0, name='obj')\n",
    "\n",
    "# constraints\n",
    "Q_con = 100.  # constraint penalty weights\n",
    "con_1 = Q_con*(x1 >= x2)\n",
    "con_2 = Q_con*((p/2)**2 <= x1**2+x2**2)\n",
    "con_3 = Q_con*(x1**2+x2**2 <= p**2)\n",
    "con_1.name = 'c1'\n",
    "con_2.name = 'c2'\n",
    "con_3.name = 'c3'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 496
    },
    "id": "n7VPa9Wc8JRB",
    "outputId": "0da17c45-6370-4f46-f626-bd5686b94bfc"
   },
   "outputs": [],
   "source": [
    "# constrained optimization problem construction\n",
    "objectives = [obj]\n",
    "constraints = [con_1, con_2, con_3]\n",
    "components = [sol_map]\n",
    "\n",
    "# create penalty method loss function\n",
    "loss = PenaltyLoss(objectives, constraints)\n",
    "# construct constrained optimization problem\n",
    "problem = Problem(components, loss)"
   ]
  },
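  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The loss assembled above can be sketched in the generic penalty-method form. This is an illustration under the assumption that each inequality is penalized by the batch mean of its ReLU-clipped violation; the exact norm and reduction used inside `PenaltyLoss` may differ. Here $\\pi_\\theta$ stands for the solution map `sol_map`, $N$ for the batch size, $Q$ for the penalty weight `Q_con`, and $f$ for the Rosenbrock objective:\n",
    "\n",
    "$$\n",
    "\\begin{align}\n",
    "\\mathcal{L}(\\theta) = \\frac{1}{N}\\sum_{i=1}^{N} \\Big[ & f(x_i, y_i; a_i) + Q\\,\\max(0,\\; y_i - x_i) \\\\\n",
    "& + Q\\,\\max\\big(0,\\; (p_i/2)^2 - x_i^2 - y_i^2\\big) + Q\\,\\max\\big(0,\\; x_i^2 + y_i^2 - p_i^2\\big) \\Big], \\\\\n",
    "& \\text{where } (x_i, y_i) = \\pi_\\theta(a_i, p_i)\n",
    "\\end{align}\n",
    "$$\n",
    "\n",
    "Each $\\max(0, \\cdot)$ term is zero whenever the corresponding constraint holds, so only violated constraints contribute to the gradient."
   ]
  },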
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Lightning Hooks: \n",
    "\n",
    "Lightning hooks are modular \"lego\" blocks that define the training process of the LightningModule. Recall that the Lightning trainer fits a `LightningModule` to a `LightningDataModule`. For user-simplicity, the LightningModule and LightningModule are abstracted away; the user only need to interact with the LitTrainer. However we provide the capability of more fine-grained control over the training process by interacting with these hooks. \n",
    "\n",
    "Let's begin with the simplest hook: the `training_step`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Custom Training Logic\n",
    "Training within PyTorch Lightning framework is defined by a `training_step` function, which defines the logic going from a data batch to loss. For example, the default training_step used is shown below (other extraneous details removed for simplicity). Here, we get the problem output for the given batch and return the loss associated with that output.\n",
    "\n",
    "```\n",
    "def training_step(self, batch):\n",
    "    output = self.problem(batch)\n",
    "    loss = output[self.train_metric]\n",
    "    return loss\n",
    "```\n",
    "\n",
    "Notice how easy this is, there is no need to call `optimizer.zero_grad()`, etc.; there is no PyTorch boilerplate. All the user needs to do is define how the loss should be generated during training. \n",
    "\n",
    "While rare, there may be instances where the user might want to define their own training logic. Potential cases include test-time data augmentation (e.g. operations on/w.r.t the data rollout), other domain augmentations, or modifications to how the output and/or loss is handled. \n",
    "\n",
    "The user can pass in their own \"training_step\" by supplying an equivalent function handler to the \"custom_training_step\" keyword of LitTrainer, for example: \n",
    "\n",
    "```\n",
    "def custom_training_step(model, batch): \n",
    "    output = model.problem(batch)\n",
    "    Q_con = 1\n",
    "    if model.current_epoch > 1: \n",
    "        Q_con = 1/10000\n",
    "    loss = Q_con*(output[model.train_metric])\n",
    "    return loss\n",
    "```\n",
    "\n",
    "The signature of this function should be `custom_training_step(model, batch)` where model is a Neuromancer Problem"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def custom_training_step(model, batch): \n",
    "    output = model.problem(batch)\n",
    "    Q_con = 1\n",
    "    if model.current_epoch > 1: \n",
    "        Q_con = 1/10000    \n",
    "    loss = Q_con*(output[model.train_metric])\n",
    "    return loss\n",
    "\n",
    "lit_trainer = LitTrainer(epochs=10, accelerator='cpu', patience=3, custom_training_step=custom_training_step)\n",
    "lit_trainer.fit(problem=problem, data_setup_function=data_setup_function, nsim=nsim,a_low=0.2, a_high=1.2, p_low=0.5, p_high=2.0)\n"
   ]
  },
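  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see what this custom step does without launching a training run, the scaling logic can be exercised on a stand-in object. The sketch below is illustrative only: `make_fake_model` is a hypothetical helper that mimics the three attributes the step reads (`problem`, `current_epoch`, `train_metric`), the metric key `'train_loss'` is an assumed name, and plain floats stand in for tensors.\n",
    "\n",
    "```python\n",
    "from types import SimpleNamespace\n",
    "\n",
    "def custom_training_step(model, batch):\n",
    "    output = model.problem(batch)\n",
    "    Q_con = 1\n",
    "    if model.current_epoch > 1:\n",
    "        Q_con = 1 / 10000\n",
    "    return Q_con * output[model.train_metric]\n",
    "\n",
    "def make_fake_model(epoch):\n",
    "    # hypothetical stub: the 'problem' always reports a raw loss of 2.0\n",
    "    return SimpleNamespace(\n",
    "        problem=lambda batch: {'train_loss': 2.0},\n",
    "        current_epoch=epoch,\n",
    "        train_metric='train_loss',\n",
    "    )\n",
    "\n",
    "# evaluate the step for epochs 0..3 to expose the switch after epoch 1\n",
    "losses = [custom_training_step(make_fake_model(epoch), batch=None) for epoch in range(4)]\n",
    "print(losses)\n",
    "```\n",
    "\n",
    "The first two entries keep the raw loss, while the later ones are scaled down by a factor of 10000, which also shrinks the gradients by the same factor."
   ]
  },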
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Below is another example of a dummy custom_training_step. Here we want to add the loss of the previous batch and accumulate into the \"current\" loss. (Again this is a dummy example and not necessarily propel ML techniques). Any sort of variables, such as \"past_loss\" can be defined by setting them as attributes of \"model\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def custom_training_step(model, batch): \n",
    "    with torch.no_grad(): \n",
    "        if model.current_epoch == 0: \n",
    "            model.past_loss = 0\n",
    "    \n",
    "    output = model.problem(batch)\n",
    "    loss = (output[model.train_metric]) + 0.5*model.past_loss\n",
    "    model.past_loss = loss.item()\n",
    "    return loss\n",
    "\n",
    "lit_trainer = LitTrainer(epochs=100, accelerator='cpu', patience=3, custom_training_step=custom_training_step)\n",
    "lit_trainer.fit(problem=problem, data_setup_function=data_setup_function, nsim=nsim,a_low=0.2, a_high=1.2, p_low=0.5, p_high=2.0)\n"
   ]
  },
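  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The state-carrying pattern above can be checked in isolation. The following sketch assumes the state is initialized once on the first call and uses a hypothetical stub (`fake_model`) whose 'problem' simply returns the batch value as the loss; the metric key `'train_loss'` is an assumed name, and plain floats replace tensors, so `.item()` is not needed.\n",
    "\n",
    "```python\n",
    "from types import SimpleNamespace\n",
    "\n",
    "def custom_training_step(model, batch):\n",
    "    # create the running state the first time the step is called\n",
    "    if not hasattr(model, 'past_loss'):\n",
    "        model.past_loss = 0.0\n",
    "    output = model.problem(batch)\n",
    "    loss = output[model.train_metric] + 0.5 * model.past_loss\n",
    "    model.past_loss = loss  # remembered for the next batch\n",
    "    return loss\n",
    "\n",
    "# hypothetical stub: the raw loss equals the batch value\n",
    "fake_model = SimpleNamespace(\n",
    "    problem=lambda batch: {'train_loss': batch},\n",
    "    train_metric='train_loss',\n",
    ")\n",
    "\n",
    "history = [custom_training_step(fake_model, raw) for raw in (1.0, 1.0, 1.0)]\n",
    "print(history)\n",
    "```\n",
    "\n",
    "With a constant raw loss of 1.0 the returned values follow the recursion $L_t = \\ell_t + 0.5\\,L_{t-1}$, giving 1.0, 1.5, and 1.75."
   ]
  },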
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## More Custom Hooks Via PyTorch Lightning\n",
    "\n",
    "The **custom_training_step** discussed above is one example of the many hooks created by PyTorch Lightning. These hooks are special methods in the `LightningModule` class that allow for customization and fine-tuning of the training, validation, and testing processes. These hooks are invoked at specific points during the training lifecycle and enable users to inject custom logic. For instance, hooks like `training_step`, `validation_step`, and `test_step` handle the core logic for processing batches during different stages. Hooks like `on_epoch_start`, `on_epoch_end`, `on_batch_start`, and `on_batch_end` allow for actions at the start and end of epochs and batches. Other hooks such as `on_train_start`, `on_train_end`, `on_validation_start`, and `on_validation_end` provide entry and exit points for the training and validation phases, allowing for setup, teardown, logging, and other custom operations. These hooks provide a structured and clean way to extend and customize the training workflow in PyTorch Lightning.\n",
    "\n",
    "For a list of all available hooks please refer to https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#hooks\n",
    "\n",
    "We can visualize this hook interface as follows: \n",
    "\n",
    "<img src=\"../../figs/hooks.png\" width=\"600\">  \n",
    "\n",
    "\n",
    "### Integrating More Hooks into NeuroMANCER\n",
    "\n",
    "We implement the hook(s) based off their respective signature (please refer to the API at the link above to find that) and pass in a hook dictionary to the LitTrainer.\n"
   ]
  },
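  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One way a dictionary of hooks could be applied is to bind each function to the module instance under the hook's name, so that it is later found as a regular method and receives the module as `self`. The sketch below illustrates that idea in plain Python; `TinyModule` and `attach_hooks` are hypothetical names, and this is not a claim about how `LitTrainer` handles `custom_hooks` internally.\n",
    "\n",
    "```python\n",
    "import types\n",
    "\n",
    "class TinyModule:\n",
    "    # hypothetical stand-in for a LightningModule with a default hook\n",
    "    def on_train_start(self):\n",
    "        return 'default start'\n",
    "\n",
    "def attach_hooks(module, hooks):\n",
    "    # bind each function to the instance so that `self` refers to the module\n",
    "    for name, fn in hooks.items():\n",
    "        setattr(module, name, types.MethodType(fn, module))\n",
    "    return module\n",
    "\n",
    "def on_train_start(self):\n",
    "    self.started = True  # hooks can store state on the module\n",
    "    return 'custom start'\n",
    "\n",
    "module = attach_hooks(TinyModule(), {'on_train_start': on_train_start})\n",
    "result = module.on_train_start()\n",
    "print(result, module.started)\n",
    "```\n",
    "\n",
    "This is why every hook function takes `self` as its first argument: once bound, it can read and write attributes of the module just like a method defined in the class body."
   ]
  },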
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define custom hooks\n",
    "def custom_train_epoch_end(self): \n",
    "    if not hasattr(self, 'train_loss_epoch_history') or self.train_loss_epoch_history is None: #define a list to store training loss\n",
    "        self.train_loss_epoch_history = []\n",
    "    \n",
    "    # get the epoch average train loss. `training_step_outputs` is a list already created to store loss per batch within the training epoch\n",
    "    epoch_average = torch.stack(self.training_step_outputs).mean()\n",
    "    self.train_loss_epoch_history.append(epoch_average)\n",
    "\n",
    "# Do similar thing for validation loss\n",
    "def custom_validation_epoch_end(self):\n",
    "    if not hasattr(self, 'val_loss_epoch_history') or self.val_loss_epoch_history is None:\n",
    "        self.val_loss_epoch_history = []\n",
    "\n",
    "    epoch_average = torch.stack(self.validation_step_outputs).mean()\n",
    "    self.val_loss_epoch_history.append(epoch_average)\n",
    "\n",
    "# Do something when training starts\n",
    "def on_train_start(self): \n",
    "    print(\"HELLO WORLD\")\n",
    "\n",
    "# optimizer with scheduler\n",
    "def configure_optimizers(self):\n",
    "    optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)\n",
    "    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)\n",
    "    return [optimizer], [scheduler]\n",
    "\n",
    "# Create a custom hooks dictionary\n",
    "custom_hooks = {\n",
    "    'on_train_epoch_end': custom_train_epoch_end,\n",
    "    'on_validation_epoch_end': custom_validation_epoch_end, \n",
    "    'on_train_start': on_train_start, \n",
    "    'configure_optimizers': configure_optimizers\n",
    "}\n",
    "\n",
    "# Initialize the trainer with custom hooks\n",
    "trainer = LitTrainer(epochs=1, accelerator='cpu', patience=3, custom_training_step=custom_training_step, custom_hooks=custom_hooks)\n",
    "\n",
    "# Assuming `problem` and `data_setup_function` are defined\n",
    "trainer.fit(problem, data_setup_function, nsim=nsim,a_low=0.2, a_high=1.2, p_low=0.5, p_high=2.0)"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "neuromancer",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
